#%%
import sys
import re
import traceback
import math
from pathlib import Path
from typing import List, Callable

# --- Required Libraries ---
try:
    import pandas as pd
    import matplotlib.pyplot as plt
    import matplotlib.tri as tri
    import numpy as np
    import scipy.optimize # Added for SimplexDynamicsPlotter
except ImportError as e:
    print(f"Error: Missing required libraries ({e}).")
    print("Please install them using: pip install pandas matplotlib numpy scipy")
    sys.exit(1)

# --- EGT Simplex Plotting Helpers ---
# Cartesian vertices of the reference simplex: an equilateral triangle with
# unit side length, first vertex at the origin, second on the positive x-axis.
simplex_r0, simplex_r1, simplex_r2 = (
    np.array([0, 0]),
    np.array([1, 0]),
    np.array([0.5, np.sqrt(3) / 2.0]),
)
# Stacked vertices, one (x, y) row per corner -> shape (3, 2).
simplex_corners = np.vstack((simplex_r0, simplex_r1, simplex_r2))
# Single-triangle triangulation of the simplex outline (x column, y column).
simplex_triangle = tri.Triangulation(*simplex_corners.T)

def ba2xy_simplex(ba_coords: np.ndarray) -> np.ndarray:
    """Map barycentric weights onto the Cartesian plane of the reference simplex.

    A 1-D input of three weights yields a single (x, y) point; a 2-D input of
    shape (N, 3) yields an (N, 2) array with one point per row. The mapping
    uses the module-level ``simplex_corners``.
    """
    coords = np.array(ba_coords)
    if coords.ndim != 1:
        # Batch of points, one barycentric triple per row.
        return simplex_corners.T.dot(coords.T).T
    # Single point: weighted sum of the three corner positions.
    return simplex_corners.T.dot(coords)

# --- Simplex Dynamics Plotting Class (Ported and adapted) ---
class SimplexDynamicsPlotter:
    """Plot a 3-strategy dynamical system on the 2-simplex.

    ``replicator_func(x, t)`` is evaluated at the nodes of a refined mesh of
    the simplex, where ``x`` is a length-3 barycentric vector and the return
    value is treated as ``dx/dt``. Construction searches for stationary points
    and precomputes the flow field; ``plot_dynamics_simplex`` then draws the
    result onto a matplotlib Axes.
    """

    # Simplex geometry shared by all instances: vertices of an equilateral
    # triangle with unit side, its triangulation, and a refined mesh (subdiv=5)
    # on which flow strength and arrows are evaluated.
    r0, r1, r2 = np.array([0,0]), np.array([1,0]), np.array([0.5, np.sqrt(3)/2.])
    corners = np.array([r0, r1, r2])
    triangle = tri.Triangulation(corners[:,0], corners[:,1])
    try:
        refiner = tri.UniformTriRefiner(triangle)
        trimesh = refiner.refine_triangulation(subdiv=5) # Mesh density for contours and quivers
    except Exception as e:
        # Best effort: __init__ retries, and every method tolerates trimesh=None.
        print(f"Warning: Could not initialize trimesh for SimplexDynamicsPlotter at class definition: {e}")
        refiner = None
        trimesh = None

    def __init__(self, replicator_func: Callable, strategy_labels: List[str], corner_label_fontsize: int = 28):
        """
        Args:
            replicator_func: Callable ``f(x, t)`` returning dx/dt for a
                barycentric vector ``x``; ``t`` is passed as 0.
            strategy_labels: Corner labels; entries 0, 1, 2 label r0, r1, r2.
            corner_label_fontsize: Font size of the corner labels.
        """
        self.f = replicator_func
        self.strategy_labels = strategy_labels
        self.corner_label_fontsize = corner_label_fontsize

        # Retry mesh construction if it failed when the class was defined.
        # This also covers the case where the refiner exists but refining failed.
        if self.trimesh is None:
            if self.triangle is None:
                print("Critical Warning: SimplexDynamicsPlotter.triangle is None, cannot create trimesh.")
            else:
                print("Info: Initializing trimesh for SimplexDynamicsPlotter in __init__.")
                try:
                    SimplexDynamicsPlotter.refiner = tri.UniformTriRefiner(self.triangle)
                    SimplexDynamicsPlotter.trimesh = SimplexDynamicsPlotter.refiner.refine_triangulation(subdiv=5)
                except Exception as e:
                    # Same best-effort policy as the class-level attempt:
                    # continue and let each method degrade gracefully.
                    print(f"Warning: Could not initialize trimesh for SimplexDynamicsPlotter in __init__: {e}")

        self.calculate_stationary_points() # Find fixed points of the dynamics
        self.calc_direction_and_strength() # Calculate flow field

    def xy2ba(self, x, y):
        """Convert Cartesian (x, y) to barycentric coordinates on ``self.corners``.

        Works element-wise, so ``x``/``y`` may be scalars or equal-shape arrays.
        Returns ``[l1, l2, 1 - l1 - l2]``, or three NaNs for a degenerate triangle.
        """
        detT = (self.corners[1,1]-self.corners[2,1])*(self.corners[0,0]-self.corners[2,0]) + \
               (self.corners[2,0]-self.corners[1,0])*(self.corners[0,1]-self.corners[2,1])
        if abs(detT)<1e-12: return np.array([np.nan]*3) # Avoid division by zero
        l1 = ((self.corners[1,1]-self.corners[2,1])*(x-self.corners[2,0]) + \
              (self.corners[2,0]-self.corners[1,0])*(y-self.corners[2,1]))/detT
        l2 = ((self.corners[2,1]-self.corners[0,1])*(x-self.corners[2,0]) + \
              (self.corners[0,0]-self.corners[2,0])*(y-self.corners[2,1]))/detT
        return np.array([l1,l2,1-l1-l2]) # l3 = 1 - l1 - l2

    def ba2xy(self, ba):
        """Convert barycentric coordinates to Cartesian (x, y) on ``self.corners``.

        A 1-D triple gives shape (2,); an (N, 3) array gives shape (N, 2).
        """
        ba=np.array(ba)
        return self.corners.T.dot(ba.T).T if ba.ndim > 1 else self.corners.T.dot(ba)

    def calculate_stationary_points(self, tol=1e-8, margin=0.005):
        """Locate stationary points of ``self.f`` and store them in ``self.fixpoints``.

        ``scipy.optimize.root`` (method "hybr") is started from every refined-mesh
        node whose barycentric coordinates are all >= ``margin``. A converged
        root is accepted when its components lie in [0, 1] (up to 1e-12) and sum
        to 1 within 1e-3. Accepted roots are projected exactly onto the simplex
        and de-duplicated (atol 1e-5). ``self.fixpoints`` holds them in Cartesian
        form, shape (K, 2), or is an empty array when none are found.

        Args:
            tol: Tolerance forwarded to ``scipy.optimize.root``.
            margin: Minimum barycentric coordinate for a start point (keeps
                starts off the simplex boundary).
        """
        fp_bary = [] # Barycentric coordinates of accepted fixed points
        if self.trimesh is None:
            print("Warning: trimesh not available in calculate_stationary_points. Skipping.")
            self.fixpoints = np.array([])
            return

        def residual(vec):
            return self.f(vec, 0)

        for x_coord, y_coord in zip(self.trimesh.x, self.trimesh.y):
            start_ba = self.xy2ba(x_coord, y_coord)
            # Interior starts only; also skip anything xy2ba could not convert.
            if np.any(np.isnan(start_ba)) or np.any(start_ba < margin):
                continue
            try:
                sol = scipy.optimize.root(residual, start_ba, method="hybr", tol=tol)
            except Exception:
                # Best effort: a start point that breaks the solver is skipped.
                continue
            if not sol.success:
                continue
            root = sol.x
            on_simplex = (math.isclose(np.sum(root), 1, abs_tol=1e-3)
                          and np.all((root > -1e-12) & (root < 1 + 1e-12)))
            if not on_simplex:
                continue
            # Remove the small drift off the x1 + x2 + x3 = 1 plane so stored
            # and plotted coordinates are true barycentric coordinates.
            root = np.clip(root, 0.0, None)
            root = root / root.sum()
            if not any(np.allclose(root, fp, atol=1e-5) for fp in fp_bary):
                fp_bary.append(root.tolist())
        # Convert found barycentric fixed points to Cartesian for plotting
        self.fixpoints = self.ba2xy(np.array(fp_bary)) if fp_bary else np.array([])
        print(f"Found {len(fp_bary)} fixed points (barycentric): {fp_bary}")

    def calc_direction_and_strength(self):
        """Evaluate the flow at every refined-mesh node.

        Sets ``self.pvals`` to the Euclidean norm of ``f`` per node (drawn as
        "flow strength"). Sets ``self.dir_xy`` to the Cartesian displacement
        after a 0.1 step along ``f``, clipped to [0, 1] and renormalised. Sets
        ``self.dir_norm_xy`` to the unit-length version of ``dir_xy``, zero
        where the displacement vanishes. Without a mesh, only ``pvals`` and
        ``dir_norm_xy`` are set, both as empty arrays.
        """
        if self.trimesh is None:
            print("Warning: trimesh not available in calc_direction_and_strength. Skipping.")
            self.pvals = np.array([]) # Flow strength
            self.dir_norm_xy = np.array([]) # Normalized flow direction vectors
            return

        # Convert all trimesh points to barycentric coordinates
        bary = np.array([self.xy2ba(x,y) for x,y in zip(self.trimesh.x, self.trimesh.y)])
        # dx/dt at each node; NaN coordinates contribute no flow
        dir_ba = np.array([self.f(ba,0) if not np.any(np.isnan(ba)) else [0,0,0] for ba in bary])

        self.pvals = np.linalg.norm(dir_ba, axis=1) # Magnitude of change (flow strength)

        # Small step along the flow, kept inside [0, 1] and renormalised to sum to 1
        # (all-zero rows fall back to the simplex centre).
        next_bary = np.clip(bary + dir_ba * 0.1, 0, 1)
        next_bary_sum = np.sum(next_bary, axis=1, keepdims=True)
        next_bary = np.divide(next_bary, next_bary_sum, out=np.full_like(next_bary, 1/3.), where=next_bary_sum!=0)

        curr_xy = self.ba2xy(bary) # Current points in Cartesian
        next_xy = self.ba2xy(next_bary) # Next points in Cartesian
        self.dir_xy = next_xy - curr_xy # Vector for quiver plot
        norms = np.linalg.norm(self.dir_xy, axis=1)
        # Unit vectors give every quiver arrow the same length; stationary nodes stay zero.
        self.dir_norm_xy = np.divide(self.dir_xy, norms[:,np.newaxis], out=np.zeros_like(self.dir_xy), where=norms[:,np.newaxis]!=0)

    def plot_dynamics_simplex(self, ax, cmap='viridis', colorbar_label_fontsize: int = 22, **kwargs):
        """Draw the simplex outline, flow-strength contours, flow arrows and fixed points.

        Args:
            ax: Matplotlib Axes to draw on; the colorbar is attached to its figure.
            cmap: Colormap for the flow-strength contours.
            colorbar_label_fontsize: Font size of the colorbar label.
            **kwargs: Extra ``tricontourf`` arguments. They override the defaults
                ``alpha=0.6, levels=14, zorder=2`` instead of colliding with them.
        """
        ax.set_facecolor('white') # Background color
        ax.triplot(self.triangle, lw=0.8, c="darkgrey", zorder=1) # Simplex outline

        if self.trimesh is None:
            print("Warning: trimesh not available in plot_dynamics_simplex. Plot will be minimal.")
        else:
            # Plot flow strength contour if available
            if hasattr(self,'pvals') and self.pvals.size > 0:
                contour_kwargs = {'alpha': 0.6, 'cmap': cmap, 'levels': 14, 'zorder': 2, **kwargs}
                contour = ax.tricontourf(self.trimesh, self.pvals, **contour_kwargs)
                # Attach to ax's own figure rather than pyplot's "current" figure.
                cb = ax.figure.colorbar(contour, ax=ax, shrink=0.7)
                cb.set_label("Flow Strength", fontsize=colorbar_label_fontsize)
            # Plot flow direction quiver plot if available
            if hasattr(self,'dir_norm_xy') and self.dir_norm_xy.size > 0:
                ax.quiver(self.trimesh.x, self.trimesh.y, self.dir_norm_xy[:,0], self.dir_norm_xy[:,1],
                          angles='xy', pivot='mid', scale=20, width=0.004, headwidth=3.5, color='black', zorder=3)

        # Plot fixed points if available
        if hasattr(self,'fixpoints') and self.fixpoints.size > 0:
            ax.scatter(self.fixpoints[:,0], self.fixpoints[:,1], c="red", s=150, marker='o',
                       edgecolors='black', lw=1.2, zorder=5, label="Fixed Points")

        # Strategy labels: below the two base corners, above the apex
        mgn = 0.05
        ax.text(self.r0[0], self.r0[1]-mgn, self.strategy_labels[0], ha='center',va='top',fontsize=self.corner_label_fontsize,weight='bold')
        ax.text(self.r1[0], self.r1[1]-mgn, self.strategy_labels[1], ha='center',va='top',fontsize=self.corner_label_fontsize,weight='bold')
        ax.text(self.r2[0], self.r2[1]+mgn*0.5, self.strategy_labels[2], ha='center',va='bottom',fontsize=self.corner_label_fontsize,weight='bold')

        ax.axis('equal'); ax.axis('off') # Equal aspect ratio, no axes
        ax.set_ylim(ymin=-0.1, ymax=self.r2[1]+0.1); ax.set_xlim(xmin=-0.1, xmax=1.1)


# --- Plot Simplex from Tournament Data Function ---
def plot_simplex_from_tournament_data(
    payoff_matrix_path: Path,
    population_history_path: Path,
    output_dir: Path,
    tournament_run_name: str,
    corner_label_fontsize: int = 28,
    colorbar_label_fontsize: int = 22,
    main_title_fontsize: int = 30,
    legend_fontsize: int = 22,
    trajectory_marker_size: int = 8,
    trajectory_line_width: float = 3.0,
    start_end_marker_size: int = 14,
    show_trajectory: bool = False
    ):
    """
    Generates and saves a simplex plot using data from tournament.py outputs.
    Reads payoff_matrix.csv and (optionally) moran_population_history.csv.
    Trajectory plotting is now controlled by the 'show_trajectory' parameter.
    """
    print(f"\n--- Generating Simplex Plot for Tournament: {tournament_run_name} ---")
    print(f"  Reading payoff matrix: {payoff_matrix_path}")
    if show_trajectory:
        print(f"  Reading population history for trajectory: {population_history_path}")

    # Load Payoff Matrix and Strategy Labels
    try:
        payoff_df = pd.read_csv(payoff_matrix_path)
        strategy_labels = list(payoff_df.columns) # Assumes headers are strategy names
        payoff_matrix = payoff_df.to_numpy()
        if len(strategy_labels) != 3 or payoff_matrix.shape != (3,3):
            print(f"Error: Simplex plot requires 3 strategies. Found {len(strategy_labels)}, matrix shape {payoff_matrix.shape}.")
            return
        print(f"  Strategy labels: {strategy_labels}")
        print(f"  Payoff matrix:\n{payoff_matrix}")
    except FileNotFoundError:
        print(f"Error: Payoff matrix file not found: {payoff_matrix_path}"); return
    except Exception as e:
        print(f"Error loading payoff matrix {payoff_matrix_path}: {e}"); return

    # Load Moran Process Population History (only if show_trajectory is True)
    population_history_data = None
    if show_trajectory:
        try:
            pop_history_df = pd.read_csv(population_history_path)
            # Ensure columns for strategy labels exist in the history file
            if not all(label in pop_history_df.columns for label in strategy_labels):
                print(f"Error: Population history CSV ({population_history_path}) is missing columns "
                      f"corresponding to strategy labels: {strategy_labels}. Found: {list(pop_history_df.columns)}"); return
            population_history_data = pop_history_df[strategy_labels].to_numpy() # Get proportions
            print(f"  Population history loaded. Steps: {len(population_history_data)}")
        except FileNotFoundError:
            print(f"Warning: Population history file not found: {population_history_path}. Plotting without trajectory.")
        except Exception as e:
            print(f"Error loading population history {population_history_path}: {e}")
            # Continue without trajectory if loading fails but show_trajectory was true
            population_history_data = None
            show_trajectory = False # Force trajectory off

    # Define Replicator Dynamics function (can be nested or module-level)
    def replicator_dyn(x_props, t, A_matrix): # x_props: proportions, t: time (unused), A_matrix: payoff matrix
        x_props = np.clip(np.array(x_props), 0, 1) # Ensure props are between 0 and 1
        x_sum = np.sum(x_props)
        # Normalize to sum to 1, handle sum is zero case
        x_props = x_props / x_sum if x_sum > 1e-9 else np.full_like(x_props, 1/len(x_props))
        
        expected_payoffs = A_matrix.dot(x_props) # E_i = sum_j (A_ij * x_j)
        average_population_payoff = x_props.dot(expected_payoffs) # phi = sum_i (x_i * E_i)
        return x_props * (expected_payoffs - average_population_payoff) # dx_i/dt = x_i * (E_i - phi)

    # Create Plot
    fig, ax = plt.subplots(figsize=(14, 13)) # Slightly adjusted figure size for even larger fonts
    plotter = None # Initialize plotter to None
    try:
        plotter = SimplexDynamicsPlotter(
            replicator_func=lambda x, t: replicator_dyn(x, t, payoff_matrix),
            strategy_labels=strategy_labels,
            corner_label_fontsize=corner_label_fontsize # Pass increased font size
        )
        # The fixed_point_marker_size is handled within SimplexDynamicsPlotter's plot_dynamics_simplex
        plotter.plot_dynamics_simplex(ax, colorbar_label_fontsize=colorbar_label_fontsize)
    except Exception as e:
        print(f"Error during simplex dynamics plotting: {e}"); traceback.print_exc(); plt.close(fig); return

    # Plot Moran Trajectory (ONLY IF show_trajectory is True and data is available)
    if show_trajectory and population_history_data is not None and plotter is not None:
        traj_ba = population_history_data # This is already a numpy array of fractions
        # Normalize rows to sum to 1 (should already be, but good practice)
        traj_ba_sum = np.sum(traj_ba, axis=1, keepdims=True)
        traj_ba_normalized = np.divide(traj_ba, traj_ba_sum,
                                       out=np.full_like(traj_ba, 1/3.), # if sum is 0, distribute equally
                                       where=traj_ba_sum != 0)
        
        traj_xy = plotter.ba2xy(traj_ba_normalized) # Convert barycentric to Cartesian
        x_coords, y_coords = traj_xy[:, 0], traj_xy[:, 1]

        # Plot the trajectory line
        ax.plot(x_coords, y_coords, c='magenta', lw=trajectory_line_width, ls='-',
                marker='.', ms=trajectory_marker_size, label='Moran Trajectory', zorder=4)
        
        # Plot start and end markers for the trajectory
        if len(x_coords) > 0:
            ax.plot(x_coords[0], y_coords[0], 'o', c='lime', ms=start_end_marker_size,
                    label='Start', zorder=6, mec='k') # Mark start
            ax.plot(x_coords[-1], y_coords[-1], 's', c='red', ms=start_end_marker_size,
                    label=f'End (Step {len(x_coords)-1})', zorder=6, mec='k') # Mark end
    else:
        print("  Moran trajectory, start, and end markers will not be plotted.")

    # Add Legend, Title, and Save Plot
    handles, legend_labels_list = ax.get_legend_handles_labels()
    if handles:
        by_label = dict(zip(legend_labels_list, handles))
        ax.legend(by_label.values(), by_label.keys(), loc='upper left',
                  bbox_to_anchor=(1.02, 1), borderaxespad=0., fontsize=legend_fontsize)
    
    fig.tight_layout(rect=[0, 0, 0.82, 0.92]) # Adjust rect for larger title/legend
    
    safe_run_name = re.sub(r'[^\w\-]+', '_', tournament_run_name) # Sanitize name for file
    outfile = output_dir / f"Simplex_Plot_Tournament_{safe_run_name}.png"
    try:
        plt.savefig(outfile, dpi=150, bbox_inches='tight'); plt.close(fig)
        print(f"  Saved Simplex plot to: {outfile}")
    except Exception as e:
        print(f"Error saving plot {outfile}: {e}"); plt.close(fig)


# --- Main Execution ---
if __name__ == "__main__":
    # tournament.py leaves one sub-directory per run under this root.
    BASE_TOURNAMENT_OUTPUT_DIR = Path("scale_tournaments")

    # Destination for the figures produced here (created if absent).
    OUTPUT_VIS_DIR = Path("simplex_visualizations")
    OUTPUT_VIS_DIR.mkdir(parents=True, exist_ok=True)

    # --- Select the tournament run to visualise ---
    tournament_run_name_to_plot = "Qwen-1_5B_vs_Qwen-32B_vs_Qwen-72B"
    run_data_dir = BASE_TOURNAMENT_OUTPUT_DIR / tournament_run_name_to_plot
    payoff_matrix_file = run_data_dir / "payoff_matrix.csv"
    pop_history_file = run_data_dir / "moran_population_history.csv"

    dir_found = run_data_dir.exists()
    payoffs_found = payoff_matrix_file.exists()

    if not (dir_found and payoffs_found):
        # Report every missing input, not just the first one.
        print(f"\nError: Could not find necessary data files for tournament '{tournament_run_name_to_plot}':")
        if not dir_found:
            print(f"  - Tournament data directory not found: {run_data_dir}")
        if not payoffs_found:
            print(f"  - Payoff matrix file not found: {payoff_matrix_file}")
        if not pop_history_file.exists():
            print(f"  - Population history file (moran_population_history.csv) not found at {pop_history_file}. Trajectory cannot be shown even if requested.")
        print("  Please ensure at least the payoff matrix exists in the specified location.")
    else:
        print(f"Attempting to plot simplex for tournament: {tournament_run_name_to_plot}")
        plot_simplex_from_tournament_data(
            payoff_matrix_path=payoff_matrix_file,
            population_history_path=pop_history_file,
            output_dir=OUTPUT_VIS_DIR,
            tournament_run_name=tournament_run_name_to_plot,
            show_trajectory=False,
        )

    print("\n--- Simplex Plotting Script Finished ---")
#%%
